import pickle
import numpy as np
import symmetries
from inspect import getmembers, isfunction
import itertools
import torch

# Shared configuration for this script:
#   folder       - directory the input pickles are read from
#   n_meas       - node count used in the pickle file names and by getPrototypes
#   magenta_size - width of every node feature row built in the main loop
folder, n_meas, magenta_size = "pickles", 16, 256

def getPrototypes(adj_mat, num_meas=None):
    """Return, for each of the first ``num_meas`` rows, the row index chosen as its prototype.

    ``adj_mat`` is summed over axis 2 into a weight matrix ``w``. For every row
    ``i < num_meas`` the strongest entry of ``w[i, num_meas:]`` is located (as an
    index relative to that slice), and the row holding the largest value in that
    column of ``w`` is reported.

    Args:
        adj_mat: array with an axis 2 to sum over; the sum is indexed as 2-D and
            must have at least ``num_meas`` rows. Presumably a per-graph adjacency
            tensor whose first ``num_meas`` nodes are measures — TODO confirm.
        num_meas: number of leading rows to process. Defaults to the module-level
            ``n_meas`` (the column slice used to be hard-coded to 16).

    Returns:
        Integer ndarray of shape ``(num_meas,)``.

    Raises:
        ValueError: if the summed matrix has fewer than ``num_meas`` rows, or no
            columns after the first ``num_meas`` (numpy's empty-argmax error).
    """
    if num_meas is None:
        num_meas = n_meas
    w = np.sum(adj_mat, axis=2)
    if w.shape[0] < num_meas:
        raise ValueError(
            "adjacency matrix has %d rows, expected at least %d" % (w.shape[0], num_meas)
        )
    # Strongest column in the region after the first num_meas columns, per row.
    prot = np.argmax(w[:num_meas, num_meas:], axis=1)
    # NOTE(review): `prot` is relative to the sliced columns, yet it is used below
    # as an absolute column of `w` without adding num_meas back. Kept as-is so
    # existing outputs do not change — confirm this is intended.
    return np.argmax(w[:, prot], axis=0)

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point `folder` at pickles produced by this project's own pipeline.
def _load_pickle(name):
    """Unpickle ``folder + "/" + name``, closing the file handle afterwards."""
    with open(folder + "/" + name, "rb") as fh:
        return pickle.load(fh)


# Every input file name is derived from n_meas (meas used to be pinned to 16).
meas = _load_pickle("meas" + str(n_meas) + ".pcl")
simple_mats = _load_pickle("simple_mats" + str(n_meas) + ".pcl")
inds = _load_pickle("inds" + str(n_meas) + ".pcl")
# Keep only the entries of `meas` selected by `inds`, in that order.
meas = [meas[i] for i in inds]
adj_mats = _load_pickle("adj_mats" + str(n_meas) + ".pcl")
# Explicit check instead of `assert`, which is stripped under `python -O`.
if len(meas) != len(simple_mats):
    raise ValueError(
        "meas and simple_mats differ in length: %d vs %d" % (len(meas), len(simple_mats))
    )
prototypes = [getPrototypes(adj_mat) for adj_mat in adj_mats]
all_rewards_vec = _load_pickle("all_rewards_vec" + str(n_meas) + ".pcl")
magents = _load_pickle("analyzedmagents.pcl")

def _first_row(bar):
    """Return ``bar[0, :]`` when that indexing succeeds, otherwise ``bar`` unchanged.

    1-D NumPy arrays raise IndexError ("too many indices") and plain lists or
    tuples raise TypeError on a tuple index; KeyError is included so mapping-like
    values also fall back. Anything else (e.g. KeyboardInterrupt) now propagates
    instead of being swallowed by a bare ``except``.
    """
    try:
        return bar[0, :]
    except (IndexError, TypeError, KeyError):
        return bar


# Reduce every bar of every graph to its first row where possible.
new_magents = [[_first_row(bar) for bar in graph_mag] for graph_mag in magents]
magents = new_magents

# Per-graph results, filled by the main loop and pickled at the end.
all_nodes, all_edge_inds, all_edge_attrs = [], [], []

# (name, function) pairs for every function member of `symmetries`, sorted by
# name as getmembers returns them.
functions_list = getmembers(symmetries, isfunction)
# Names used as symmetry flags; "mod12Same" is deliberately left out
# (reason not visible here).
functions_name_list = [name for name, _fn in functions_list if name != "mod12Same"]

num_symmetries = len(functions_name_list)
# Two blocks of per-symmetry flags followed by three single-slot entries.
num_total_edge_attrs = 2 * num_symmetries + 3



def _set_symmetry_flags(edge_attr, rewards_vec, src, dst, offset, names, graph_idx):
    """Set ``edge_attr[offset + k] = 1`` for each ``names[k]`` found in ``rewards_vec[src][dst]``.

    Best-effort: if the lookup fails (missing key/index) or the entry does not
    support ``in``, the problem is printed with its context and no further flags
    are set, instead of every exception being swallowed by a bare ``except``.
    """
    try:
        entry = rewards_vec[src][dst]
        for k_ind, name in enumerate(names):
            if name in entry:
                edge_attr[offset + k_ind] = 1
    except (KeyError, IndexError, TypeError) as err:
        print("error: graph %d, rewards_vec[%s][%s]: %r" % (graph_idx, src, dst, err))


# Edge-attribute layout (length num_total_edge_attrs = 2*S + 3, S = num_symmetries):
#   [0, S)   symmetry flags on prototype -> bar edges
#   [S, 2S)  symmetry flags on bar -> preceding-bar edges
#   2S       never set in this script. NOTE(review): probably intended as the
#            prototype-edge type flag — confirm before relying on it.
#   2S + 1   sequential (bar h-1 -> bar h) edge flag
#   2S + 2   neighbour (bar h -> one of its 3 preceding bars) edge flag
_SEQ_FLAG = 2 * num_symmetries + 1
_NEIGHBOR_FLAG = 2 * num_symmetries + 2

for z, bars in enumerate(magents):
    protos = prototypes[z]
    rewards_vec = all_rewards_vec[z]
    sorted_prototypes = sorted(set(protos))
    n_proto = len(sorted_prototypes)
    n_bars = len(bars)
    # Node row that represents each distinct prototype value.
    proto_node = {p: row for row, p in enumerate(sorted_prototypes)}

    # Node rows: [0, n_proto) prototype nodes, then one row per bar.
    x = np.zeros((n_proto + n_bars, magenta_size))
    for row, p in enumerate(sorted_prototypes):
        try:
            x[row] = bars[p]
        except (IndexError, ValueError) as err:
            # Prototype index outside this graph's bars (or a shape mismatch):
            # fall back to measure 0's prototype, as before.
            print("error: graph %d, prototype %s: %r" % (z, p, err))
            x[row] = bars[protos[0]]
    # Fill every bar row; previously only the first n_proto bars were copied,
    # leaving the remaining bar nodes all-zero.
    for h, bar in enumerate(bars):
        x[n_proto + h] = bar

    edge_inds = []
    edge_attrs = []

    # Sequential edges between consecutive bars.
    for h in range(1, n_bars):
        edge_inds.append([n_proto + h - 1, n_proto + h])
        edge_attr = np.zeros(num_total_edge_attrs)
        edge_attr[_SEQ_FLAG] = 1
        edge_attrs.append(edge_attr)

    # Prototype edges: the node of measure h's prototype -> bar node h.
    # The source is the prototype's node row, not its raw measure index, and
    # the range is capped at n_bars so no edge points past the last node.
    # NOTE(review): h starts at 1, so measure 0 never gets a prototype edge — confirm.
    for h in range(1, min(len(protos), n_bars)):
        edge_attr = np.zeros(num_total_edge_attrs)
        _set_symmetry_flags(
            edge_attr, rewards_vec, h, protos[h], 0, functions_name_list, z
        )
        edge_attrs.append(edge_attr)
        edge_inds.append([proto_node[protos[h]], n_proto + h])

    # Neighbour edges: bar h -> each of the (up to) 3 preceding bars.
    for h in range(n_bars):
        for neighbor in range(max(0, h - 3), h):
            edge_inds.append([n_proto + h, n_proto + neighbor])
            edge_attr = np.zeros(num_total_edge_attrs)
            _set_symmetry_flags(
                edge_attr, rewards_vec, h, neighbor, num_symmetries, functions_name_list, z
            )
            edge_attr[_NEIGHBOR_FLAG] = 1
            edge_attrs.append(edge_attr)

    all_nodes.append(x)
    all_edge_inds.append(edge_inds)
    all_edge_attrs.append(edge_attrs)

with open(folder + "/graphelements.pcl", "wb") as out_file:
    pickle.dump((all_nodes, all_edge_inds, all_edge_attrs), out_file)
